#----------------------------------------------------------------------
#  GFDM method test - 3d heat equation, Mixed BC
#  Author: Andrea Pavan
#  Date: 22/12/2022
#  License: GPLv3-or-later
#----------------------------------------------------------------------
using ElasticArrays;
using LinearAlgebra;
using SparseArrays;
using Printf;
using PyPlot;
include("utils.jl");


#problem definition
#the rod axis is along z: the two end faces are the nodes with z=0 and z=l1
l1 = 5.0;       #rod length along z (end faces at z=0 and z=l1)
l2 = 2.5;       #cross-section size parameter (enters the lateral-surface test as (l2-l3)/2, the |x| offset of the arc centers)
l3 = 2.0;       #cross-section size parameter (l3/2 is the arc radius and -l3/2 the y of the arc centers in the lateral-surface test)
uL = 400;       #temperature imposed on the end face z=l1
uR = 300;       #temperature imposed on the end face z=0

meshSize = 0.25;        #distance target between internal nodes
surfaceMeshSize = 0.25;        #distance target between boundary nodes
minNeighbors = 25;       #minimum number of neighbors allowed
minSearchRadius = meshSize/2;       #starting search radius (not referenced elsewhere in this script)


#read pointcloud from a SU2 file
time1 = time();
pointcloud = ElasticArray{Float64}(undef,3,0);      #3xN matrix containing the coordinates [X;Y;Z] of each node
boundaryNodes = Vector{Int}(undef,0);       #indices of the boundary nodes
internalNodes = Vector{Int}(undef,0);       #indices of the internal nodes
normals = ElasticArray{Float64}(undef,3,0);     #3xN matrix containing the components [nx;ny;nz] of the normal of each node (zero column for internal nodes)

#pointcloud = parseSU2mesh("13b_direct_3d_heat_rod_su2_mesh_3626.su2");
pointcloud = parseSU2mesh("13b_direct_3d_heat_rod_su2_mesh_21850.su2");
N = size(pointcloud,2);

#classify every node as boundary or internal and store its normal.
#A normal column is appended for EVERY node (zeros for internal nodes), so normals[:,i]
#always belongs to node i. The boundary-condition setup reads normals[:,i] with i a node
#index; storing normals only for boundary nodes made that lookup correct only when the
#boundary nodes happened to come first in the pointcloud.
for i=1:N
    if pointcloud[3,i]==0
        #end face z=0
        push!(boundaryNodes, i);
        append!(normals, [0,0,-1]);
    elseif pointcloud[3,i]==l1
        #end face z=l1
        push!(boundaryNodes, i);
        append!(normals, [0,0,1]);
    elseif pointcloud[1,i]==0 || (abs(pointcloud[1,i])-(l2-l3)/2)^2+(pointcloud[2,i]+l3/2)^2>=(l3/2-1e-3)^2
        #lateral surface: nodes on the plane x=0, or whose xy-distance from (±(l2-l3)/2, -l3/2)
        #is at least l3/2 (with a 1e-3 tolerance); the normal is that xy offset, normalized.
        #NOTE(review): relpos uses abs(x), so its x component is the same for x and -x; if the
        #cross-section is mirror-symmetric about x=0, nodes with x<0 would need the opposite
        #x sign. Likewise, the x==0 test marks every node on that plane as boundary. Confirm
        #the intended geometry against the mesh.
        relpos = [abs(pointcloud[1,i])-(l2-l3)/2, pointcloud[2,i]+l3/2, 0];
        push!(boundaryNodes, i);
        append!(normals, relpos./sqrt(relpos'relpos));
    else
        push!(internalNodes, i);
        append!(normals, [0,0,0]);      #placeholder, keeps normals indexed by node
    end
end

#report timing, node counts and memory footprint of the pointcloud
println("Generated pointcloud in $(round(time()-time1, digits=2)) s");
println("Pointcloud properties:");
println("  Boundary nodes: $(length(boundaryNodes))");
println("  Internal nodes: $(length(internalNodes))");
println("  Memory: $(memoryUsage(pointcloud,boundaryNodes))");

#pointcloud plot: internal nodes as black dots
#(boundary nodes can be added with plot3D(pointcloud[1,boundaryNodes],pointcloud[2,boundaryNodes],pointcloud[3,boundaryNodes],"r.");)
let xyzInternal = pointcloud[:,internalNodes]
    figure();
    plot3D(xyzInternal[1,:], xyzInternal[2,:], xyzInternal[3,:], "k.");
end
title("Pointcloud plot");
axis("equal");
display(gcf());


#boundary conditions
#uD[i]: Dirichlet value prescribed at node i, NaN when node i has no Dirichlet condition
#uN[i]: Neumann value (normal derivative) prescribed at node i, 0 everywhere here
N = size(pointcloud,2);     #number of nodes
uD = fill(NaN, N);
uN = zeros(Float64, N);
for node in boundaryNodes
    zNode = pointcloud[3,node];
    if zNode==0
        uD[node] = uR;      #end face z=0
    end
    if zNode==l1
        uD[node] = uL;      #end face z=l1
    end
end


#neighbor search
time2 = time();
(neighbors,Nneighbors,cell) = cartesianNeighborSearch(pointcloud,meshSize,minNeighbors);
#(neighbors,Nneighbors) = quadrantNeighborSearch(pointcloud,meshSize);
println("Found neighbors in $(round(time()-time2, digits=2)) s");
println("Connectivity properties:");
#findmax/findmin return the extreme value together with the index of its first occurrence
(nNeighMax, iNeighMax) = findmax(Nneighbors);
(nNeighMin, iNeighMin) = findmin(Nneighbors);
println("  Max neighbors: $(nNeighMax) (at index $(iNeighMax))");
println("  Avg neighbors: $(round(sum(Nneighbors)/length(Nneighbors), digits=2))");
println("  Min neighbors: $(nNeighMin) (at index $(iNeighMin))");

#connectivity plot
#=figure(1);
#plot1Idx = rand(1:N,5);
plot1Idx = rand(1+length(boundaryNodes):N,5);
plot3D(pointcloud[1,:],pointcloud[2,:],pointcloud[3,:],marker=".",linestyle="None",color="lightgray");
for i in plot1Idx
    connColor = rand(3);
    plot3D(pointcloud[1,neighbors[i]],pointcloud[2,neighbors[i]],pointcloud[3,neighbors[i]],marker=".",linestyle="None",color=connColor);
    for j in neighbors[i]
        plot3D([pointcloud[1,i],pointcloud[1,j]],[pointcloud[2,i],pointcloud[2,j]],[pointcloud[3,i],pointcloud[3,j]],"-",color=connColor);
    end
end
plot3D(pointcloud[1,plot1Idx],pointcloud[2,plot1Idx],pointcloud[3,plot1Idx],"k.");
title("Connectivity plot");
axis("equal");
display(gcf());=#


#neighbors distances and weights
time3 = time();
P = Vector{Array{Float64}}(undef,N);        #P[i][:,j]: position of the j-th neighbor of node i relative to node i
r2 = Vector{Vector{Float64}}(undef,N);      #r2[i][j]: squared distance between node i and its j-th neighbor
w2 = Vector{Vector{Float64}}(undef,N);      #w2[i][j]: squared gaussian weight of the j-th neighbor of node i
for i=1:N
    nn = Nneighbors[i];
    relPos = Array{Float64}(undef,3,nn);
    dist2 = Vector{Float64}(undef,nn);
    for j=1:nn
        offset = pointcloud[:,neighbors[i][j]]-pointcloud[:,i];
        relPos[:,j] = offset;
        dist2[j] = offset'offset;
    end
    #distances are normalized by the farthest neighbor of node i
    dist2max = maximum(dist2);
    P[i] = relPos;
    r2[i] = dist2;
    w2[i] = @. exp(-dist2/dist2max)^2;
end
w2pde = 2.0;        #least squares weight for the pde
w2bc = 2.0;     #least squares weight for the boundary condition


#least square matrix inversion

"""
    gfdmLeastSquares(Pi, w2i, extraRows, extraWeights)

Assemble and solve the weighted least-squares problem of one node. The unknowns are the 10
coefficients of the quadratic polynomial
a0 + a1 x + a2 y + a3 z + a4 x^2 + a5 y^2 + a6 z^2 + a7 xy + a8 xz + a9 yz.

- `Pi`: 3xn matrix, column j is the position of neighbor j relative to the node.
- `w2i`: the n least-squares weights of the neighbors.
- `extraRows`: mx10 matrix of extra constraint rows acting on the coefficients (pde, bc).
- `extraWeights`: the m least-squares weights of the extra rows.

Returns `(A, C)`, where `A = V'WV` is the 10x10 normal matrix and `C = A^-1 V'W` is the
10x(n+m) matrix mapping [neighbor values; extra-row right-hand sides] to the coefficients.
"""
function gfdmLeastSquares(Pi, w2i, extraRows, extraWeights)
    n = size(Pi,2);
    V = zeros(Float64, n+size(extraRows,1), 10);
    for j=1:n
        (x,y,z) = (Pi[1,j], Pi[2,j], Pi[3,j]);
        V[j,:] = [1, x, y, z, x^2, y^2, z^2, x*y, x*z, y*z];
    end
    V[n+1:end,:] = extraRows;
    W = Diagonal(vcat(w2i, extraWeights));
    A = transpose(V)*W*V;
    #solve A*C = V'W with the QR factorization of A instead of forming inv(R) explicitly
    C = qr(A) \ (transpose(V)*W);
    return (A, C);
end

A = Vector{Matrix{Float64}}(undef,N);        #least-squares matrices
C = Vector{Matrix{Float64}}(undef,N);        #derivatives coefficients matrices
lapRow = [0.0 0 0 0 2 2 2 0 0 0];       #laplacian of the quadratic basis: pde row
for i in internalNodes
    (A[i], C[i]) = gfdmLeastSquares(P[i], w2[i], lapRow, [w2pde]);
end
for i in boundaryNodes
    if !isnan(uD[i])
        bcRow = [1.0 0 0 0 0 0 0 0 0 0];        #Dirichlet: value of the polynomial at the node
    else
        bcRow = [0.0 normals[1,i] normals[2,i] normals[3,i] 0 0 0 0 0 0];        #Neumann: normal derivative at the node
    end
    (A[i], C[i]) = gfdmLeastSquares(P[i], w2[i], vcat(lapRow, bcRow), [w2pde, w2bc]);
end
println("Inverted least-squares matrices in ", round(time()-time3,digits=2), " s");


#matrix assembly
#row i encodes u_i - sum_j C[i][1,j]*u(neighbors[i][j]); triplets are stored in the same order as before
time4 = time();
rows = Int[];
cols = Int[];
vals = Float64[];
for i=1:N
    nn = Nneighbors[i];
    append!(rows, fill(i, nn+1));
    push!(cols, i);
    append!(cols, neighbors[i][1:nn]);
    push!(vals, 1.0);
    append!(vals, -C[i][1,1:nn]);
end
M = sparse(rows,cols,vals,N,N);
println("Completed matrix assembly in $(round(time()-time4, digits=2)) s");


#linear system solution
time5 = time();
b = zeros(N);       #rhs vector (internal nodes keep 0: nothing is added for the pde row)
for i in boundaryNodes
    #boundary-condition value: the Dirichlet value when present, otherwise the Neumann value
    bcValue = isnan(uD[i]) ? uN[i] : uD[i];
    b[i] += C[i][1,end]*bcValue;
end
u = M\b;
println("Linear system solved in $(round(time()-time5, digits=2)) s");

#error calculation against the analytic solution, linear in z from uR (z=0) to uL (z=l1)
zc = pointcloud[3,:];
ue = uR .+ zc .* (uL-uR) ./ l1;
erru = ue .- u;        #numerical solution error
maxerru = maximum(abs, erru);
rmse = sqrt(sum(erru.^2)/length(erru));
println("  Max error: $(maxerru)");
println("  RMSE: $(rmse)");

#solution plot: numerical temperature at the internal nodes
figure();
plt = scatter3D(pointcloud[1,internalNodes], pointcloud[2,internalNodes], pointcloud[3,internalNodes];
    c=u[internalNodes], cmap="inferno");
title("Numerical solution");
axis("equal");
colorbar(plt);
display(gcf());

#temperature against z: analytic (black) and GFDM (red)
figure();
for (curve, curveStyle, curveLabel) in ((ue, "k.", "Analytic"), (u, "r.", "GFDM"))
    plot(pointcloud[3,:], curve, curveStyle, label=curveLabel);
end
title("Numerical solution u(z)");
legend(loc="upper left");
xlabel("z coordinate");
ylabel("temperature T");
display(gcf());


#error plot: pointwise error (analytic - numerical) at the internal nodes
figure();
plt = scatter3D(pointcloud[1,internalNodes], pointcloud[2,internalNodes], pointcloud[3,internalNodes];
    c=erru[internalNodes], cmap="inferno");
title("Solution error");
axis("equal");
colorbar(plt);
display(gcf());
